{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Get started with PyTorch Fully Sharded Data Parallel (FSDP2) and Ray Train\n",
    "\n",
    "**Time to complete:** 30 min\n",
    "\n",
    "This template shows how to integrate PyTorch's Fully Sharded Data Parallel (FSDP2) with Ray Train to gain memory and performance improvements.\n",
    "\n",
    "PyTorch's FSDP2 enables model sharding across nodes, allowing distributed training of large models with a significantly smaller memory footprint compared to standard Distributed Data Parallel (DDP). For a more detailed overview of FSDP2, see [PyTorch's official documentation](https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html#getting-started-with-fully-sharded-data-parallel-fsdp2). \n",
    "\n",
    "This tutorial provides a comprehensive, step-by-step guide on integrating PyTorch FSDP2 with Ray Train. Specifically, this guide covers the following: \n",
    "\n",
    "- A hands-on example of training an image classification model\n",
    "- Configuring FSDP2 to mitigate out-of-memory (OOM) errors using mixed precision, CPU offloading, sharding granularity, and more\n",
    "- Model checkpoint saving and loading with PyTorch Distributed Checkpoint (DCP)\n",
    "- GPU memory profiling with PyTorch Profiler\n",
    "- Loading a distributed model for inference\n",
    "\n",
    "**Note:** This notebook uses FSDP2's `fully_shard` API. If you're using FSDP1's `FullyShardedDataParallel`, consider migrating to FSDP2 for improved performance and features such as lower memory usage and `DTensor` integration."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div id=\"anyscale-note\" class=\"alert alert-block alert-warning\">\n",
    "\n",
    "  <strong>Anyscale Specific Configuration</strong>\n",
    "\n",
    "  <p><strong>Note:</strong> This tutorial is optimized for the Anyscale platform. When running on open source Ray, additional configuration is required. For example, you would need to manually:</p>\n",
    "\n",
    "  <ul>\n",
    "    <li><strong>Configure your Ray Cluster</strong>: Set up your multi-node environment and manage resource allocation without Anyscale's automation.</li>\n",
    "    <li><strong>Manage Dependencies</strong>: Manually install and manage dependencies on each node.</li>\n",
    "    <li><strong>Set Up Storage</strong>: Configure your own distributed or shared storage system for model checkpointing.</li>\n",
    "  </ul>\n",
    "</div>\n",
    "\n",
    "<style>\n",
    "  div#anyscale-note > p,\n",
    "  div#anyscale-note > ul,\n",
    "  div#anyscale-note > ul li {\n",
    "    color: black;\n",
    "  }\n",
    "\n",
    "  div#anyscale-note {\n",
    "    background-color: rgb(255, 243, 205);\n",
    "  }\n",
    "\n",
    "  div#anyscale-note {\n",
    "    border: 1px solid #ccc; \n",
    "    border-radius: 8px;\n",
    "    padding: 15px;\n",
    "  }\n",
    "\n",
    "</style>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example overview\n",
    "\n",
    "For demonstration purposes, this tutorial integrates Ray Train with FSDP2 using a **Vision Transformer (ViT)** trained on the FashionMNIST dataset. A ViT is a good fit because its clear, repeated block structure (transformer encoder blocks) is ideal for demonstrating FSDP2's sharding capabilities. \n",
    "\n",
    "While this example is relatively simple, FSDP2's complexity can lead to common challenges during training, such as out-of-memory (OOM) errors. This guide addresses these issues with practical tips for improving performance and reducing memory utilization based on your specific use case."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Package and model setup\n",
    "\n",
    "Install the required dependencies for this tutorial:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "pip install torch\n",
    "pip install torchvision\n",
    "pip install matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 124,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable Ray Train V2 for the latest train APIs\n",
    "import os\n",
    "os.environ[\"RAY_TRAIN_V2_ENABLED\"] = \"1\"\n",
    "\n",
    "# Profiling and utilities\n",
    "import torch.profiler\n",
    "import tempfile\n",
    "import uuid\n",
    "import logging\n",
    "\n",
    "# Set up logging\n",
    "logger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model definition\n",
    "The following function initializes a Vision Transformer (ViT) model configured for the FashionMNIST dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Computer vision components\n",
    "from torchvision.models import VisionTransformer\n",
    "from torchvision.datasets import FashionMNIST\n",
    "from torchvision.transforms import ToTensor, Normalize, Compose\n",
    "\n",
    "def init_model() -> torch.nn.Module:\n",
    "    \"\"\"Initialize a Vision Transformer model for FashionMNIST classification.\n",
    "    \n",
    "    Returns:\n",
    "        torch.nn.Module: Configured ViT model\n",
    "    \"\"\"\n",
    "    logger.info(\"Initializing Vision Transformer model...\")\n",
    "\n",
    "    # Create a ViT model with architecture suitable for 28x28 images\n",
    "    model = VisionTransformer(\n",
    "        image_size=28,        # FashionMNIST image size\n",
    "        patch_size=7,         # Divide the 28x28 image into a 4x4 grid of 7x7 patches\n",
    "        num_layers=10,        # Number of transformer encoder layers\n",
    "        num_heads=2,          # Number of attention heads per layer\n",
    "        hidden_dim=128,       # Hidden dimension size\n",
    "        mlp_dim=128,          # MLP dimension in transformer blocks\n",
    "        num_classes=10,       # FashionMNIST has 10 classes\n",
    "    )\n",
    "\n",
    "    # Modify the patch embedding layer for grayscale images (1 channel instead of 3)\n",
    "    model.conv_proj = torch.nn.Conv2d(\n",
    "        in_channels=1,        # FashionMNIST is grayscale (1 channel)\n",
    "        out_channels=128,     # Match the hidden_dim\n",
    "        kernel_size=7,        # Match patch_size\n",
    "        stride=7,             # Non-overlapping patches\n",
    "    )\n",
    "\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Define the training function\n",
    "\n",
    "The following is the main training function that orchestrates the FSDP2 training process. Subsequent sections implement each of the helper functions used in this training loop. First, import the modules the training function needs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ray Train imports\n",
    "import ray\n",
    "import ray.train\n",
    "import ray.train.torch\n",
    "\n",
    "# PyTorch Core import\n",
    "import torch\n",
    "\n",
    "# PyTorch training components\n",
    "from torch.nn import CrossEntropyLoss\n",
    "from torch.optim import Adam\n",
    "from torch.utils.data import DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_func(config):\n",
    "    \"\"\"Main training function that integrates FSDP2 with Ray Train.\n",
    "    \n",
    "    Args:\n",
    "        config: Training configuration dictionary containing hyperparameters\n",
    "    \"\"\"\n",
    "    # Initialize the model\n",
    "    model = init_model()\n",
    "\n",
    "    # Configure device and move model to GPU\n",
    "    device = ray.train.torch.get_device()\n",
    "    torch.cuda.set_device(device)\n",
    "    model.to(device)\n",
    "\n",
    "    # Apply FSDP2 sharding to the model\n",
    "    shard_model(model)\n",
    "\n",
    "    # Initialize loss function and optimizer\n",
    "    criterion = CrossEntropyLoss()\n",
    "    optimizer = Adam(model.parameters(), lr=config.get('learning_rate', 0.001))\n",
    "\n",
    "    # Load from checkpoint if available (for resuming training)\n",
    "    start_epoch = 0\n",
    "    loaded_checkpoint = ray.train.get_checkpoint()\n",
    "    if loaded_checkpoint:\n",
    "        latest_epoch = load_fsdp_checkpoint(model, optimizer, loaded_checkpoint)\n",
    "        start_epoch = latest_epoch + 1 if latest_epoch is not None else 0\n",
    "        logger.info(f\"Resuming training from epoch {start_epoch}\")\n",
    "\n",
    "    # Prepare training data\n",
    "    transform = Compose([\n",
    "        ToTensor(), \n",
    "        Normalize((0.5,), (0.5,))\n",
    "    ])\n",
    "    data_dir = os.path.join(tempfile.gettempdir(), \"data\")\n",
    "    train_data = FashionMNIST(\n",
    "        root=data_dir, train=True, download=True, transform=transform\n",
    "    )\n",
    "    train_loader = DataLoader(\n",
    "        train_data, \n",
    "        batch_size=config.get('batch_size', 64), \n",
    "        shuffle=True\n",
    "    )\n",
    "    # Prepare data loader for distributed training\n",
    "    train_loader = ray.train.torch.prepare_data_loader(train_loader)\n",
    "\n",
    "    world_rank = ray.train.get_context().get_world_rank()\n",
    "\n",
    "    # Set up PyTorch Profiler for memory monitoring\n",
    "    with torch.profiler.profile(\n",
    "        activities=[\n",
    "            torch.profiler.ProfilerActivity.CPU,\n",
    "            torch.profiler.ProfilerActivity.CUDA,\n",
    "        ],\n",
    "        schedule=torch.profiler.schedule(wait=0, warmup=0, active=6, repeat=1),\n",
    "        record_shapes=True,\n",
    "        profile_memory=True,\n",
    "        with_stack=True,\n",
    "    ) as prof:\n",
    "\n",
    "        # Main training loop\n",
    "        running_loss = 0.0\n",
    "        num_batches = 0\n",
    "        epochs = config.get('epochs', 5)\n",
    "        \n",
    "        for epoch in range(start_epoch, epochs):\n",
    "            # Set epoch for distributed sampler to ensure proper shuffling\n",
    "            if ray.train.get_context().get_world_size() > 1:\n",
    "                train_loader.sampler.set_epoch(epoch)\n",
    "\n",
    "            for images, labels in train_loader:\n",
    "                # Note: prepare_data_loader automatically moves data to the correct device\n",
    "                outputs = model(images)\n",
    "                loss = criterion(outputs, labels)\n",
    "                \n",
    "                # Standard training step\n",
    "                optimizer.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "                \n",
    "                # Update profiler\n",
    "                prof.step()\n",
    "                \n",
    "                # Track metrics\n",
    "                running_loss += loss.item()\n",
    "                num_batches += 1\n",
    "\n",
    "            # Report metrics and save checkpoint after each epoch\n",
    "            avg_loss = running_loss / num_batches\n",
    "            metrics = {\"loss\": avg_loss}\n",
    "            report_metrics_and_save_fsdp_checkpoint(model, optimizer, metrics, epoch)\n",
    "\n",
    "            # Log metrics from rank 0 only to avoid duplicate outputs\n",
    "            if world_rank == 0:\n",
    "                logger.info(metrics)\n",
    "    \n",
    "    # Export memory profiling results to cluster storage\n",
    "    run_name = ray.train.get_context().get_experiment_name()\n",
    "    prof.export_memory_timeline(\n",
    "        f\"/mnt/cluster_storage/{run_name}/rank{world_rank}_memory_profile.html\"\n",
    "    )\n",
    "\n",
    "    # Save the final model for inference\n",
    "    save_model_for_inference(model, world_rank)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Storage configuration\n",
    "\n",
    "This demo uses cluster storage to allow for quick iteration and development, but this may not be suitable in production environments or at high scale. In those cases, you should use object storage instead. For more information about how to select your storage type, see the [Anyscale storage configuration docs](https://docs.anyscale.com/configuration/storage)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Model sharding with FSDP2\n",
    "\n",
    "PyTorch's `fully_shard` enables sharding at various granularities. At the most granular level, you can shard every layer to minimize peak memory utilization, but this also increases communication costs between Ray Train workers. Experiment with different sharding granularities to find the optimal balance for your use case. This example only shards the encoder blocks—the largest layers in the Vision Transformer.\n",
    "\n",
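    "The granularity trade-off can be sketched with rough, assumed numbers (the byte counts below are illustrative, not measured from this model):\n",
    "\n",
    "```python\n",
    "# Finer shard units shrink the unsharded working set but add all-gathers per step.\n",
    "total_param_bytes = 400_000_000  # assumed total parameter size in bytes\n",
    "granularities = {\"whole model\": 1, \"per encoder block\": 10, \"per layer\": 60}\n",
    "\n",
    "for name, num_units in granularities.items():\n",
    "    working_set = total_param_bytes // num_units  # bytes unsharded at any one time\n",
    "    print(name, working_set, num_units)           # num_units all-gathers per forward\n",
    "```\n",
    "\n",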
    "Beyond sharding granularity, FSDP2 offers several configuration options to optimize performance and mitigate OOM errors:\n",
    "\n",
    "### Device mesh configuration\n",
    "\n",
    "`init_device_mesh` configures a `DeviceMesh` that describes the training run's device topology. This example uses a simple 1D mesh for data parallelism, but `DeviceMesh` also supports multi-dimensional parallelism approaches, including tensor parallelism and pipeline parallelism. In many cases, combining several types of parallelism can further improve training performance.\n",
    "\n",
    "For more information about advanced multi-dimensional parallelism configurations, see the [PyTorch device mesh documentation](https://docs.pytorch.org/tutorials/recipes/distributed_device_mesh.html).\n",
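    "\n",
    "As a quick sketch, a multi-dimensional mesh needs dimensions whose product equals the total number of workers. The sizes below are illustrative assumptions, not values used elsewhere in this tutorial:\n",
    "\n",
    "```python\n",
    "# Hypothetical 8-GPU cluster arranged as 4-way data parallel x 2-way tensor parallel.\n",
    "world_size = 8\n",
    "dp_size, tp_size = 4, 2\n",
    "assert dp_size * tp_size == world_size\n",
    "\n",
    "# The corresponding call (not executed here) would look like:\n",
    "# mesh = init_device_mesh(\"cuda\", mesh_shape=(dp_size, tp_size),\n",
    "#                         mesh_dim_names=(\"data_parallel\", \"tensor_parallel\"))\n",
    "```\n",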
    "\n",
    "### CPU offloading \n",
    "\n",
    "CPU offloading reduces the GPU memory footprint by storing model components in CPU memory. However, this comes with the trade-off of increased data-transfer overhead between CPU and GPU during computation.\n",
    "\n",
    "**CPU offloading does the following:**\n",
    "- Stores sharded parameters, gradients, and optimizer states on CPU\n",
    "- Copies sharded parameters to GPU during forward/backward computation and frees them after use\n",
    "- Copies computed gradients to the CPU where PyTorch computes the optimizer step\n",
    "\n",
    "**When to use CPU offloading:**\n",
    "- When GPU memory is constrained\n",
    "- For very large models that don't fit in GPU memory\n",
    "\n",
    "**Don't use CPU offloading in the following cases:**\n",
    "- When CPU memory is limited (offloading can cause host out-of-memory crashes)\n",
    "- When training speed is more important than memory usage\n",
    "\n",
    "<div style=\"display: flex; gap: 40px; align-items: flex-start;\">\n",
    "  <div style=\"text-align: center;\">\n",
    "    <h3>Without CPU offloading</h3>\n",
    "    <img src=\"https://raw.githubusercontent.com/ray-project/ray/master/doc/source/train/examples/pytorch/pytorch-fsdp/images/gpu_memory_profile.png\" width=\"600\"/>\n",
    "  </div>\n",
    "  <div style=\"text-align: center;\">\n",
    "    <h3>With CPU offloading</h3>\n",
    "    <img src=\"https://raw.githubusercontent.com/ray-project/ray/master/doc/source/train/examples/pytorch/pytorch-fsdp/images/cpu_offload_profile.png\" width=\"600\"/>\n",
    "  </div>\n",
    "</div>\n",
    "**Note:** PyTorch's memory profiler, which this tutorial covers later, generated the preceding images.\n",
    "\n",
    "CPU offloading significantly reduces the amount of GPU memory occupied by model parameters.\n",
    "\n",
    "Learn more about CPU offloading in the [PyTorch documentation](https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.CPUOffloadPolicy).\n",
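    "\n",
    "A back-of-the-envelope sketch of what offloading moves off the GPU, using an assumed shard size (not measured from this model):\n",
    "\n",
    "```python\n",
    "# GPU bytes freed per worker by offloading fp32 parameters, gradients,\n",
    "# and Adam's two optimizer-state tensors (exp_avg and exp_avg_sq).\n",
    "shard_params = 1_000_000      # parameters in this worker's shard (assumption)\n",
    "bytes_per_fp32 = 4\n",
    "param_bytes = shard_params * bytes_per_fp32\n",
    "grad_bytes = param_bytes      # one gradient per parameter\n",
    "adam_bytes = 2 * param_bytes  # two state tensors per parameter\n",
    "print(param_bytes + grad_bytes + adam_bytes)  # 16000000 bytes moved to CPU\n",
    "```\n",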
    "\n",
    "\n",
    "### `reshard_after_forward` flag \n",
    "`fully_shard` has a `reshard_after_forward` flag that frees all-gathered model weights immediately after the forward pass. This reduces peak GPU memory usage but increases communication overhead between workers during the backward pass, because parameters need to be all-gathered again. If the unsharded model parameters fit comfortably on each worker and don't pose a memory bottleneck, there's no need to enable `reshard_after_forward`.\n",
    "\n",
    "<div style=\"display: flex; gap: 40px; align-items: flex-start;\">\n",
    "  <div style=\"text-align: center;\"> \n",
    "    <h3><code>reshard_after_forward=False</code></h3>\n",
    "    <img src=\"https://raw.githubusercontent.com/ray-project/ray/master/doc/source/train/examples/pytorch/pytorch-fsdp/images/gpu_memory_profile.png\" width=\"600\"/>\n",
    "  </div>\n",
    "  <div style=\"text-align: center;\">\n",
    "    <h3><code>reshard_after_forward=True</code></h3>\n",
    "    <img src=\"https://raw.githubusercontent.com/ray-project/ray/master/doc/source/train/examples/pytorch/pytorch-fsdp/images/reshard_after_forward_memory_profile.png\" width=\"600\"/>\n",
    "  </div>\n",
    "</div>\n",
    "\n",
    "With `reshard_after_forward=True`, the memory allocated to model parameters drops after the forward step, whereas it stays at its peak when `reshard_after_forward=False`.\n",
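    "\n",
    "A rough sketch of why the peak drops, using assumed block sizes (illustrative numbers only):\n",
    "\n",
    "```python\n",
    "# Peak unsharded parameter memory during one training step (assumed sizes).\n",
    "num_blocks = 10           # encoder blocks in this tutorial's ViT\n",
    "block_bytes = 50_000_000  # unsharded size of one block (assumption)\n",
    "\n",
    "# reshard_after_forward=False: every block stays unsharded after its forward.\n",
    "peak_without = num_blocks * block_bytes\n",
    "\n",
    "# reshard_after_forward=True: roughly one block is all-gathered at a time.\n",
    "peak_with = block_bytes\n",
    "\n",
    "print(peak_without, peak_with)  # 500000000 50000000\n",
    "```\n",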
    "\n",
    "### Mixed precision\n",
    "\n",
    "Enabling mixed precision accelerates training and reduces GPU memory usage with minimal accuracy impact.\n",
    "\n",
    "**Benefits of mixed precision with FSDP2:**\n",
    "- Reduced memory usage for activations and intermediate computations\n",
    "- Faster computation on modern GPUs\n",
    "- Maintained numerical stability through selective precision\n",
    "\n",
    "<div style=\"display: flex; gap: 40px; align-items: flex-start;\">\n",
    "  <div style=\"text-align: center;\">\n",
    "    <h3>Without mixed precision</h3>\n",
    "    <img src=\"https://raw.githubusercontent.com/ray-project/ray/master/doc/source/train/examples/pytorch/pytorch-fsdp/images/gpu_memory_profile.png\" width=\"600\"/>\n",
    "  </div>\n",
    "  <div style=\"text-align: center;\">\n",
    "    <h3>With mixed precision</h3>\n",
    "    <img src=\"https://raw.githubusercontent.com/ray-project/ray/master/doc/source/train/examples/pytorch/pytorch-fsdp/images/mixed_precision_profile.png\" width=\"600\"/>\n",
    "  </div>\n",
    "</div>\n",
    "\n",
    "With mixed precision enabled, the peak memory allocated to activations is halved.\n",
    "\n",
    "Learn more about mixed precision configuration in the [PyTorch documentation](https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.MixedPrecisionPolicy).\n",
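    "\n",
    "The halving follows directly from element width. A sketch with assumed activation counts (not measured from this run):\n",
    "\n",
    "```python\n",
    "# Activation memory for one batch: fp32 vs. fp16 (assumed element count).\n",
    "batch, tokens, hidden, layers = 64, 17, 128, 10\n",
    "elems = batch * tokens * hidden * layers\n",
    "fp32_bytes = elems * 4\n",
    "fp16_bytes = elems * 2\n",
    "assert fp16_bytes * 2 == fp32_bytes  # exactly half the activation memory\n",
    "```\n",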
    "\n",
    "### Combining memory strategies\n",
    "\n",
    "The following diagram compares the GPU memory profile of default sharding with the profile when all of the preceding strategies are enabled: CPU offloading, mixed precision, and `reshard_after_forward=True`.\n",
    "\n",
    "<div style=\"display: flex; gap: 40px; align-items: flex-start;\">\n",
    "  <div style=\"text-align: center;\">\n",
    "    <h3>Default sharding</h3>\n",
    "    <img src=\"https://raw.githubusercontent.com/ray-project/ray/master/doc/source/train/examples/pytorch/pytorch-fsdp/images/gpu_memory_profile.png\" width=\"600\"/>\n",
    "  </div>\n",
    "  <div style=\"text-align: center;\">\n",
    "    <h3>Combined CPU offloading, mixed precision, and resharding</h3>\n",
    "    <img src=\"https://raw.githubusercontent.com/ray-project/ray/master/doc/source/train/examples/pytorch/pytorch-fsdp/images/all_strategies_profile.png\" width=\"600\"/>\n",
    "  </div>\n",
    "</div>\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "metadata": {},
   "outputs": [],
   "source": [
    "# FSDP2 sharding imports \n",
    "from torch.distributed.fsdp import (\n",
    "    fully_shard,\n",
    "    FSDPModule,\n",
    "    CPUOffloadPolicy,\n",
    "    MixedPrecisionPolicy,\n",
    ")\n",
    "from torch.distributed.device_mesh import init_device_mesh "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 144,
   "metadata": {},
   "outputs": [],
   "source": [
    "def shard_model(model: torch.nn.Module): \n",
    "    \"\"\"Apply FSDP2 sharding to the model with optimized configuration.\n",
    "    \n",
    "    Args:\n",
    "        model: The PyTorch model to shard\n",
    "    \"\"\"\n",
    "    logger.info(\"Applying FSDP2 sharding to model...\")\n",
    "\n",
    "    # Step 1: Create 1D device mesh for data parallel sharding\n",
    "    world_size = ray.train.get_context().get_world_size()\n",
    "    mesh = init_device_mesh(\n",
    "        device_type=\"cuda\", \n",
    "        mesh_shape=(world_size,), \n",
    "        mesh_dim_names=(\"data_parallel\",)\n",
    "    )\n",
    "\n",
    "    # Step 2: Configure CPU offloading policy (optional)\n",
    "    offload_policy = CPUOffloadPolicy()\n",
    "\n",
    "    # Step 3: Configure mixed precision policy (optional)\n",
    "    mp_policy = MixedPrecisionPolicy(\n",
    "        param_dtype=torch.float16,    # Store parameters in half precision\n",
    "        reduce_dtype=torch.float16,   # Use half precision for gradient reduction\n",
    "    )\n",
    "\n",
    "    # Step 4: Apply sharding to each transformer encoder block\n",
    "    for encoder_block in model.encoder.layers.children():\n",
    "        fully_shard(\n",
    "            encoder_block, \n",
    "            mesh=mesh, \n",
    "            reshard_after_forward=True,   # Free memory after forward pass\n",
    "            offload_policy=offload_policy, \n",
    "            mp_policy=mp_policy\n",
    "        )\n",
    "\n",
    "    # Step 5: Apply sharding to the root model\n",
    "    # This wraps the entire model and enables top-level FSDP2 functionality\n",
    "    fully_shard(\n",
    "        model, \n",
    "        mesh=mesh, \n",
    "        reshard_after_forward=True,   # Free memory after forward pass\n",
    "        offload_policy=offload_policy, \n",
    "        mp_policy=mp_policy\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Distributed checkpointing\n",
    "This section sets up distributed checkpointing, loads a distributed model from a checkpoint, saves distributed model checkpoints, and saves a model for inference. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Distributed checkpoint wrapper setup\n",
    "\n",
    "This section creates a checkpointing wrapper using PyTorch's `Stateful` API to simplify distributed checkpoint management. Adapted from the PyTorch docs, this basic wrapper handles the complexities of saving and loading FSDP2 model states across multiple workers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PyTorch Distributed Checkpoint (DCP) imports\n",
    "from torch.distributed.checkpoint.state_dict import (\n",
    "    get_state_dict,\n",
    "    set_state_dict,\n",
    "    get_model_state_dict,\n",
    "    StateDictOptions\n",
    ")\n",
    "from torch.distributed.checkpoint.stateful import Stateful"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AppState(Stateful):\n",
    "    \"\"\"This is a useful wrapper for checkpointing the Application State. Because this object is compliant\n",
    "    with the Stateful protocol, PyTorch DCP automatically calls state_dict/load_state_dict as needed in the\n",
    "    dcp.save/load APIs.\n",
    "\n",
    "    Note: This wrapper is used to handle calling distributed state dict methods on the model\n",
    "    and optimizer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, optimizer=None, epoch=None):\n",
    "        self.model = model\n",
    "        self.optimizer = optimizer\n",
    "        self.epoch = epoch\n",
    "\n",
    "    def state_dict(self):\n",
    "        # get_state_dict automatically manages FSDP2 FQNs (fully qualified names) and defaults to a sharded state dict type\n",
    "        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)\n",
    "        return {\n",
    "            \"model\": model_state_dict,\n",
    "            \"optim\": optimizer_state_dict,\n",
    "            \"epoch\": self.epoch\n",
    "        }\n",
    "\n",
    "    def load_state_dict(self, state_dict):\n",
    "        # Set the state dicts on the model and optimizer now that loading is complete\n",
    "        set_state_dict(\n",
    "            self.model,\n",
    "            self.optimizer,\n",
    "            model_state_dict=state_dict[\"model\"],\n",
    "            optim_state_dict=state_dict[\"optim\"],\n",
    "        )\n",
    "        # Load epoch information if available\n",
    "        if \"epoch\" in state_dict:\n",
    "            self.epoch = state_dict[\"epoch\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load distributed model from checkpoint\n",
    "\n",
    "Load distributed checkpoints using `dcp.load`, which automatically handles resharding when the number of workers changes between training runs. This flexibility allows you to resume training with different resource configurations. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 133,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PyTorch Distributed Checkpoint (DCP) Core import\n",
    "import torch.distributed.checkpoint as dcp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_fsdp_checkpoint(model: FSDPModule, optimizer: torch.optim.Optimizer, ckpt: ray.train.Checkpoint) -> int | None:\n",
    "    \"\"\"Load an FSDP checkpoint into the model and optimizer.\n",
    "    \n",
    "    This function handles distributed checkpoint loading with automatic resharding\n",
    "    support. It can restore checkpoints even when the number of workers differs\n",
    "    from the original training run.\n",
    "    \n",
    "    Args:\n",
    "        model: The FSDP-wrapped model to load state into\n",
    "        optimizer: The optimizer to load state into\n",
    "        ckpt: Ray Train checkpoint containing the saved state\n",
    "\n",
    "    Returns:\n",
    "        int | None: The epoch number saved in the checkpoint, or None if unavailable.\n",
    "    \"\"\"\n",
    "    logger.info(\"Loading distributed checkpoint for resuming training...\")\n",
    "    \n",
    "    try:\n",
    "        with ckpt.as_directory() as checkpoint_dir:\n",
    "            # Create state wrapper for DCP loading\n",
    "            app_state = AppState(model, optimizer)\n",
    "            state_dict = {\"app\": app_state}\n",
    "            \n",
    "            # Load the distributed checkpoint\n",
    "            dcp.load(\n",
    "                state_dict=state_dict,\n",
    "                checkpoint_id=checkpoint_dir\n",
    "            )\n",
    "            \n",
    "        logger.info(f\"Successfully loaded distributed checkpoint from epoch {app_state.epoch}\")\n",
    "        return app_state.epoch\n",
    "    except Exception as e:\n",
    "        logger.error(f\"Failed to load checkpoint: {e}\")\n",
    "        raise RuntimeError(f\"Checkpoint loading failed: {e}\") from e"
   ]
  },
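  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The resharding that `dcp.load` performs can be pictured with a plain-Python toy example. The shard layout below is illustrative only; DCP's actual on-disk format and redistribution logic are more involved:\n",
    "\n",
    "```python\n",
    "# The same flat parameter list divided across 4 workers, then across 2.\n",
    "params = list(range(8))\n",
    "\n",
    "def shard(items, world_size):\n",
    "    per_rank = len(items) // world_size\n",
    "    return [items[r * per_rank:(r + 1) * per_rank] for r in range(world_size)]\n",
    "\n",
    "shards_4 = shard(params, 4)  # [[0, 1], [2, 3], [4, 5], [6, 7]]\n",
    "shards_2 = shard(params, 2)  # [[0, 1, 2, 3], [4, 5, 6, 7]]\n",
    "\n",
    "# Every element survives the change in world size.\n",
    "assert sorted(sum(shards_4, [])) == sorted(sum(shards_2, [])) == params\n",
    "```"
   ]
  },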
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save model checkpoints\n",
    "\n",
    "The following function handles periodic checkpoint saving during training, combining metrics reporting with distributed checkpoint storage:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def report_metrics_and_save_fsdp_checkpoint(\n",
    "    model: FSDPModule, optimizer: torch.optim.Optimizer, metrics: dict, epoch: int = 0\n",
    ") -> None:\n",
    "    \"\"\"Report training metrics and save an FSDP checkpoint.\n",
    "    \n",
    "    This function performs two critical operations:\n",
    "    1. Saves the current model and optimizer state using distributed checkpointing\n",
    "    2. Reports metrics to Ray Train for tracking\n",
    "    \n",
    "    Args:\n",
    "        model: The FSDP-wrapped model to checkpoint\n",
    "        optimizer: The optimizer to checkpoint\n",
    "        metrics: Dictionary of metrics to report (e.g., loss, accuracy)\n",
    "        epoch: The current epoch to be saved\n",
    "    \"\"\"\n",
    "    logger.info(\"Saving checkpoint and reporting metrics...\")\n",
    "    \n",
    "    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:\n",
    "        # Perform a distributed checkpoint with DCP\n",
    "        state_dict = {\"app\": AppState(model, optimizer, epoch)}\n",
    "        dcp.save(state_dict=state_dict, checkpoint_id=temp_checkpoint_dir)\n",
    "\n",
    "        # Report each checkpoint shard from all workers\n",
    "        # This saves the checkpoint to shared cluster storage for persistence\n",
    "        checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)\n",
    "        ray.train.report(metrics, checkpoint=checkpoint)\n",
    "        \n",
    "    logger.info(f\"Checkpoint saved successfully. Metrics: {metrics}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save the model for inference\n",
    "\n",
    "After training, it's often useful to consolidate sharded checkpoints into a single file for convenient sharing or inference. Unlike regular distributed checkpointing, this process produces a single large artifact compatible with `torch.load`. To do so, the `get_model_state_dict` function all-gathers parameter shards to rank 0, reconstructs the full state dict, and then saves the consolidated checkpoint to cluster storage.\n",
    "\n",
    "A key limitation of this approach is that the entire model must be materialized in memory on rank 0. For large models, this can exceed the available CPU RAM and result in out-of-memory errors. In such cases, keep the model in its sharded format and rely on distributed model loading for inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 136,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_model_for_inference(model: FSDPModule, world_rank: int) -> None:\n",
    "    \"\"\"Save the complete unsharded model for inference.\n",
    "    \n",
    "    This function consolidates the distributed model weights into a single\n",
    "    checkpoint file that can be used for inference without FSDP.\n",
    "    \n",
    "    Args:\n",
    "        model: The FSDP2-wrapped model to save\n",
    "        world_rank: The rank of the current worker\n",
    "    \"\"\"\n",
    "    logger.info(\"Preparing model for inference...\")\n",
    "    \n",
    "    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:\n",
    "        save_file = os.path.join(temp_checkpoint_dir, \"full-model.pt\")\n",
    "\n",
    "        # Step 1: All-gather the model state across all ranks\n",
    "        # This reconstructs the complete model from distributed shards\n",
    "        model_state_dict = get_model_state_dict(\n",
    "            model=model,\n",
    "            options=StateDictOptions(\n",
    "                full_state_dict=True,    # Reconstruct full model\n",
    "                cpu_offload=True,        # Move to CPU to save GPU memory\n",
    "            )\n",
    "        )\n",
    "\n",
    "        logger.info(\"Successfully retrieved complete model state dict\")\n",
    "        checkpoint = None\n",
    "\n",
    "        # Step 2: Save the complete model (rank 0 only)\n",
    "        if world_rank == 0: \n",
    "            torch.save(model_state_dict, save_file)\n",
    "            logger.info(f\"Saved complete model to {save_file}\")\n",
    "\n",
    "            # Create checkpoint for shared storage\n",
    "            checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)\n",
    "\n",
    "        # Step 3: Report the final checkpoint to Ray Train\n",
    "        ray.train.report(\n",
    "            {}, \n",
    "            checkpoint=checkpoint, \n",
    "            checkpoint_dir_name=\"full_model\"\n",
    "        )"
   ]
  },
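  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the full model is too large to materialize on one rank, you can instead load the sharded checkpoint in a distributed fashion, with each rank restoring only its own shards. The following sketch is illustrative rather than part of this tutorial's training code; it assumes `model` is already wrapped with `fully_shard` on every rank, that `checkpoint_dir` points at a sharded DCP checkpoint, and that the checkpoint stores the model state dict at the top level:\n",
    "\n",
    "```python\n",
    "import torch.distributed.checkpoint as dcp\n",
    "from torch.distributed.checkpoint.state_dict import (\n",
    "    get_model_state_dict,\n",
    "    set_model_state_dict,\n",
    ")\n",
    "\n",
    "def load_sharded_model(model, checkpoint_dir: str) -> None:\n",
    "    # Each rank loads only its own DTensor shards; the full model is\n",
    "    # never assembled in one place.\n",
    "    state_dict = get_model_state_dict(model)\n",
    "    dcp.load(state_dict, checkpoint_id=checkpoint_dir)\n",
    "    set_model_state_dict(model, state_dict)\n",
    "```\n"
   ]
  },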
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Launching the distributed training job\n",
    "\n",
    "This section configures and launches the distributed training job using Ray Train's TorchTrainer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure distributed training resources\n",
    "scaling_config = ray.train.ScalingConfig(\n",
    "    num_workers=2,      # Number of distributed workers\n",
    "    use_gpu=True        # Enable GPU training\n",
    ")\n",
    "\n",
    "# Configure training parameters\n",
    "train_loop_config = {\n",
    "    \"epochs\": 5,\n",
    "    \"learning_rate\": 0.001,\n",
    "    \"batch_size\": 64,\n",
    "}\n",
    "\n",
    "# Create experiment name\n",
    "experiment_name=f\"fsdp_mnist_{uuid.uuid4().hex[:8]}\"\n",
    "\n",
    "# Configure run settings and storage\n",
    "run_config = ray.train.RunConfig(\n",
    "    # Persistent storage path accessible across all worker nodes\n",
    "    storage_path=\"/mnt/cluster_storage/\",\n",
    "    # Unique experiment name (use consistent name to resume from checkpoints)\n",
    "    name=experiment_name,\n",
    "    # Fault tolerance configuration\n",
    "    failure_config=ray.train.FailureConfig(max_failures=1),\n",
    ")\n",
    "\n",
    "# Initialize and launch the distributed training job\n",
    "trainer = ray.train.torch.TorchTrainer(\n",
    "    train_loop_per_worker=train_func,\n",
    "    scaling_config=scaling_config,\n",
    "    train_loop_config=train_loop_config,\n",
    "    run_config=run_config,\n",
    ")\n",
    "\n",
    "print(\"Starting FSDP2 training job...\")\n",
    "result = trainer.fit()\n",
    "print(\"Training completed successfully!\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GPU memory profiling\n",
    "\n",
    "GPU memory profiling is a useful tool for monitoring and analyzing memory usage during model training. It helps identify bottlenecks, optimize resource allocation, and prevent OOM errors. PyTorch's GPU Memory Profiler is configured within the training function.\n",
    "\n",
    "In this demo, the profiler is configured to generate a profiling file for each worker accessible from cluster storage under the Anyscale Files tab. To inspect a worker's memory profile, download the corresponding HTML file and open it in your browser. The profiler configuration and export path can be customized within the training function. For more details on PyTorch's memory profiler, see the [PyTorch blog](https://pytorch.org/blog/understanding-gpu-memory-1/).\n",
    "\n",
    "<div style=\"display: flex; gap: 40px; align-items: flex-start;\">\n",
    "  <div style=\"text-align: center;\">\n",
    "    <h3>Example memory profile</h3>\n",
    "    <img src=\"https://raw.githubusercontent.com/ray-project/ray/master/doc/source/train/examples/pytorch/pytorch-fsdp/images/gpu_memory_profile.png\" width=\"600\"/>\n",
    "  </div>\n",
    "</div>\n"
   ]
  },
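  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The setup inside the training function follows the standard CUDA memory snapshot pattern, sketched below. The output path and the `rank` variable are illustrative, and `torch.cuda._memory_viz` is a private module whose API may change between PyTorch releases:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch.cuda._memory_viz import trace_plot\n",
    "\n",
    "# Start recording allocation events; cap the history to bound overhead.\n",
    "torch.cuda.memory._record_memory_history(max_entries=100_000)\n",
    "\n",
    "# ... run the training loop ...\n",
    "\n",
    "# Capture a snapshot and render it as a standalone HTML report.\n",
    "snapshot = torch.cuda.memory._snapshot()\n",
    "with open(f\"rank{rank}_memory_profile.html\", \"w\") as f:\n",
    "    f.write(trace_plot(snapshot))\n",
    "\n",
    "torch.cuda.memory._record_memory_history(enabled=None)  # stop recording\n",
    "```\n"
   ]
  },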
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Post-training directory view\n",
    "The Anyscale platform saves the checkpoint shards, full model, and memory profiling reports in cluster storage with the following layout:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "/mnt/cluster_storage/fsdp_mnist_<experiment_id>/\n",
    "├── checkpoint_1/\n",
    "│   ├── __0_0.distcp                # Shard file for rank 0\n",
    "│   └── __1_0.distcp                # Shard file for rank 1\n",
    "├── checkpoint_2/\n",
    "│   └── ... (similar structure)\n",
    "├── checkpoint_3/\n",
    "│   └── ... (similar structure)\n",
    "├── ...                             # Additional checkpoints\n",
    "├── full_model/\n",
    "│   └── full-model.pt               # Full model checkpoint (for inference/deployment)\n",
    "├── checkpoint_manager_snapshot.json\n",
    "├── rank0_memory_profile.html       # Memory profiling for rank 0\n",
    "└── rank1_memory_profile.html       # Memory profiling for rank 1\n",
    "```"
   ]
  },
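  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an alternative to consolidating the model during training, PyTorch's DCP format utilities can convert a sharded checkpoint into a single `torch.load`-compatible file offline. A minimal sketch with illustrative paths follows; note that the conversion also materializes the full state dict in memory, so the same RAM caveat applies for very large models:\n",
    "\n",
    "```python\n",
    "from torch.distributed.checkpoint.format_utils import dcp_to_torch_save\n",
    "\n",
    "# Merge the per-rank .distcp shards into one torch.save-style file.\n",
    "dcp_to_torch_save(\n",
    "    \"/mnt/cluster_storage/<experiment_name>/checkpoint_3\",  # sharded DCP directory\n",
    "    \"/mnt/cluster_storage/<experiment_name>/converted-model.pt\",\n",
    ")\n",
    "```\n"
   ]
  },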
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading the trained model for inference\n",
    "\n",
    "After training completes, you can load the consolidated checkpoint for inference on new data. Because the model was saved in its unsharded form, loading requires only standard PyTorch, with no FSDP or Ray Train involvement."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Update this path to match your trained model location\n",
    "# The path follows the pattern: /mnt/cluster_storage/{experiment_name}/full_model/full-model.pt\n",
    "PATH_TO_FULL_MODEL = f\"/mnt/cluster_storage/{experiment_name}/full_model/full-model.pt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the same model architecture for inference\n",
    "model = init_model()\n",
    "\n",
    "# Load the trained weights \n",
    "state_dict = torch.load(PATH_TO_FULL_MODEL, map_location='cpu')\n",
    "model.load_state_dict(state_dict)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the test data\n",
    "transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])\n",
    "test_data = FashionMNIST(\n",
    "    root=\".\", train=False, download=True, transform=transform\n",
    ")\n",
    "test_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test model inference\n",
    "# Index the dataset (not the raw .data tensor) so the ToTensor and\n",
    "# Normalize transforms are applied, matching the training preprocessing.\n",
    "with torch.no_grad():\n",
    "    image, test_label = test_data[0]\n",
    "    out = model(image.unsqueeze(0))\n",
    "    predicted_label = out.argmax().item()\n",
    "    print(f\"{predicted_label=} {test_label=}\")"
   ]
  },
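  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A single sample is only a smoke test. For a more reliable read on the model, you can evaluate accuracy over the full test set with a `DataLoader`, which applies the same normalization transform as training. This sketch reuses the `model` and `test_data` objects defined above:\n",
    "\n",
    "```python\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "loader = DataLoader(test_data, batch_size=256)\n",
    "correct = total = 0\n",
    "with torch.no_grad():\n",
    "    for images, labels in loader:\n",
    "        preds = model(images).argmax(dim=1)\n",
    "        correct += (preds == labels).sum().item()\n",
    "        total += labels.size(0)\n",
    "print(f\"Test accuracy: {correct / total:.2%}\")\n",
    "```\n"
   ]
  },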
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "In this tutorial, you did the following: \n",
    "\n",
    "- Trained an image classification model using FSDP2 and Ray Train\n",
    "- Saved and loaded distributed checkpoints with PyTorch DCP\n",
    "- Configured FSDP2 to balance training performance and memory usage\n",
    "- Profiled multi-node GPU memory usage with the PyTorch Memory Profiler"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  },
  "orphan": true
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
